from collections import deque
import random
from matplotlib import cm
import math
from math import sqrt, pi
from datetime import datetime
import re
import torch
import os
import numpy as np
from PIL import Image
import torch.nn as nn
from scipy import ndimage
import matplotlib.pyplot as plt
from model import Semantic_Mapping
import cv2
import logging
from pathlib import Path
import time
from collections import deque, defaultdict
import json
from scipy.ndimage import binary_fill_holes, binary_dilation
from scipy.ndimage import label

import torch
import numpy as np
from math import sqrt, pi


def get_local_map_boundaries(agent_loc, local_sizes, full_sizes, global_downscaling):
    """Compute the window of the full map that forms the local map.

    The window is centred on the agent and then shifted (never resized) so
    that it lies inside the full map. When ``global_downscaling`` is not
    greater than 1 the local map is the entire full map.

    Args:
        agent_loc: (row, col) cell of the agent in the full map.
        local_sizes: (local_w, local_h) window size in cells.
        full_sizes: (full_w, full_h) full-map size in cells.
        global_downscaling: ratio between full and local map size.

    Returns:
        list: ``[gx1, gx2, gy1, gy2]``; rows ``gx1:gx2`` and columns
        ``gy1:gy2`` of the full map.
    """
    row, col = agent_loc
    win_w, win_h = local_sizes
    lim_w, lim_h = full_sizes

    if global_downscaling <= 1:
        return [0, lim_w, 0, lim_h]

    def _fit(start, size, limit):
        # Slide [start, start + size) so it sits inside [0, limit).
        stop = start + size
        if start < 0:
            start, stop = 0, size
        if stop > limit:
            start, stop = limit - size, limit
        return start, stop

    gx1, gx2 = _fit(row - win_w // 2, win_w, lim_w)
    gy1, gy2 = _fit(col - win_h // 2, win_h, lim_h)
    return [gx1, gx2, gy1, gy2]


def init_map_and_pose(full_map, full_pose, map_size_cm, origins, planner_pose_inputs,
                      map_resolution, num_scenes, lmb, local_map, local_pose,
                      device, local_w, local_h, full_w, full_h, global_downscaling):
    """Reset every scene's maps and poses to the start-of-episode state.

    The agent is placed at the centre of the full map (``map_size_cm / 2``
    metres in x and y). A 3x3 marker is written into channels 2-3 of
    ``full_map`` around it. The local-map boundaries (``lmb``) and origins
    are recomputed. The local map is cropped from the full map, and the
    local pose is the full pose minus the local origin. All arrays/tensors
    passed in are modified in place and also returned.

    Returns:
        tuple: (full_map, full_pose, planner_pose_inputs, origins,
        local_map, local_pose).
    """
    full_map.fill_(0.)
    full_pose.fill_(0.)
    full_pose[:, :2] = map_size_cm / 100.0 / 2.0

    poses_np = full_pose.cpu().numpy()
    planner_pose_inputs[:, :3] = poses_np

    for env in range(num_scenes):
        r, c = poses_np[env, 1], poses_np[env, 0]
        row = int(r * 100.0 / map_resolution)
        col = int(c * 100.0 / map_resolution)
        full_map[env, 2:4, row - 1:row + 2, col - 1:col + 2] = 1.0

        lmb[env] = get_local_map_boundaries((row, col),
                                            (local_w, local_h),
                                            (full_w, full_h),
                                            global_downscaling)
        planner_pose_inputs[env, 3:] = lmb[env]

        x1, x2, y1, y2 = lmb[env]
        # Local origin in metres: x from the column offset, y from the row offset.
        origins[env] = [y1 * map_resolution / 100.0,
                        x1 * map_resolution / 100.0, 0.]

        # Crop the agent-centred window out of the full map; the local map
        # keeps its size, only its content is refreshed.
        local_map[env] = full_map[env, :, x1:x2, y1:y2]
        local_pose[env] = full_pose[env] - \
            torch.from_numpy(origins[env]).to(device).float()

    return full_map, full_pose, planner_pose_inputs, origins, local_map, local_pose


def update_full_map_for_env(full_map, local_map, lmb, e):
    """Copy scene ``e``'s local map back into its window of the full map.

    ``lmb[e]`` holds ``[row_start, row_end, col_start, col_end]``.
    ``full_map`` is modified in place and returned.
    """
    rows = slice(lmb[e, 0], lmb[e, 1])
    cols = slice(lmb[e, 2], lmb[e, 3])
    full_map[e, :, rows, cols] = local_map[e]
    return full_map


def init_map_and_pose_for_env(full_map, full_pose, planner_pose_inputs, origins, local_map, local_pose, map_size_cm, map_resolution, local_w, local_h, full_w, full_h, global_downscaling, device, lmb, e):
    """Reset the maps and pose of the single scene ``e``.

    Same procedure as the all-scene reset, restricted to index ``e``:
    - clear the full map and pose, and put the agent at the map centre;
    - mark a 3x3 cell block in channels 2-3;
    - recompute the local window/origin;
    - re-crop the local map and re-express the local pose.

    Everything is modified in place and also returned.

    Returns:
        tuple: (full_map, full_pose, planner_pose_inputs, origins, lmb,
        local_map, local_pose).
    """
    full_map[e].fill_(0.)
    full_pose[e].fill_(0.)
    full_pose[e, :2] = map_size_cm / 100.0 / 2.0

    pose_np = full_pose[e].cpu().numpy()
    planner_pose_inputs[e, :3] = pose_np
    row = int(pose_np[1] * 100.0 / map_resolution)
    col = int(pose_np[0] * 100.0 / map_resolution)
    full_map[e, 2:4, row - 1:row + 2, col - 1:col + 2] = 1.0

    lmb[e] = get_local_map_boundaries((row, col),
                                      (local_w, local_h),
                                      (full_w, full_h),
                                      global_downscaling)
    planner_pose_inputs[e, 3:] = lmb[e]

    x1, x2, y1, y2 = lmb[e]
    origins[e] = [y1 * map_resolution / 100.0,
                  x1 * map_resolution / 100.0, 0.]

    local_map[e] = full_map[e, :, x1:x2, y1:y2]
    local_pose[e] = full_pose[e] - \
        torch.from_numpy(origins[e]).to(device).float()
    return full_map, full_pose, planner_pose_inputs, origins, lmb, local_map, local_pose


def save_image_batch(batch_array, save_dir="policy_vis", img_format='png'):
    """Save each image of a batch as ``<save_dir>/image_XXX.<img_format>``.

    Args:
        batch_array: numpy array or sequence of images, each shaped
            (height, width, channels) and written with PIL mode 'RGB'.
            uint8 images are saved unchanged. For any other dtype: if the
            maximum is <= 1.0 the data is treated as normalised and scaled
            by 255. The values are then clipped to [0, 255] before the
            uint8 cast, so out-of-range pixels saturate instead of wrapping
            around (a bare ``astype(np.uint8)`` maps e.g. 256 -> 0, -1 -> 255).
        save_dir: destination directory; created if it does not exist.
        img_format: file extension; PIL selects the encoder from it.

    Raises:
        Any exception from conversion or saving, after printing its type
        and message.
    """
    os.makedirs(save_dir, exist_ok=True)
    try:
        for i, img_data in enumerate(batch_array):
            if img_data.dtype != np.uint8:
                if img_data.max() <= 1.0:  # normalised [0, 1] data
                    img_data = img_data * 255
                img_data = np.clip(img_data, 0, 255).astype(np.uint8)

            # NOTE(review): recent Pillow releases deprecate the ``mode``
            # argument of ``Image.fromarray``; kept to preserve current
            # behaviour -- confirm against the installed Pillow version.
            pil_image = Image.fromarray(img_data, mode='RGB')

            filename = f"image_{i:03d}.{img_format}"
            save_path = os.path.join(save_dir, filename)

            # The format is chosen from the file suffix.
            pil_image.save(save_path)
            print(f"成功保存：{save_path}")

    except Exception as e:
        print(f"保存失败，错误类型：{type(e).__name__}，详细信息：{str(e)}")
        raise


def prepare_planner_inputs(full_map, goal_maps, candidate_goal_map, candidate_goal_masks, filled_mask, obstacle_map, planner_pose_inputs,
                           num_scenes, found_goal, visualize, print_images,
                           is_rotatation=None, policy_vis: "list | None" = None):
    """Build one planner-input dict per scene from the full (global) map.

    Args:
        full_map: full map tensor, indexed ``[scene, channel, H, W]``.
            Channel 0 is sent as ``map_pred``, channel 1 as ``exp_pred``,
            and channels 4+ are used for ``sem_map_pred``.
        goal_maps: per-scene goal maps (``goal_maps[e]``).
        candidate_goal_map, candidate_goal_masks, filled_mask, obstacle_map:
            passed through unchanged to every scene's dict.
        planner_pose_inputs: (num_scenes, 7) array; only the first three
            values (global x, y, orientation) are used.
        num_scenes: number of scenes.
        found_goal: per-scene goal-found flags.
        visualize, print_images: if either is true, also compute
            ``sem_map_pred`` (the argmax over channels 4+).
        is_rotatation: stored as ``rotate_agent``; defaults to a new
            ``[False, None]`` on every call.
        policy_vis: ``[enabled, images]``; when ``enabled`` the images are
            stored as ``rgb_vis``, otherwise ``rgb_vis`` is None. Defaults
            to a new ``[False, []]`` on every call.

    Returns:
        tuple: (list of per-scene dicts, full_map).
    """
    # Fresh defaults per call: list literals in the signature were shared
    # across calls, and ``is_rotatation`` is stored in every returned dict,
    # so a caller mutating one result would alter the default for later calls.
    if is_rotatation is None:
        is_rotatation = [False, None]
    if policy_vis is None:
        policy_vis = [False, []]

    planner_inputs = [{} for _ in range(num_scenes)]
    for e, p_input in enumerate(planner_inputs):
        p_input['map_pred'] = full_map[e, 0, :, :].cpu().numpy()
        p_input['exp_pred'] = full_map[e, 1, :, :].cpu().numpy()

        # Only the global (x, y, orientation) part of the pose.
        p_input['pose_pred'] = planner_pose_inputs[e][:3]
        p_input['candidate_goals'] = candidate_goal_map
        p_input['candidate_goal_masks'] = candidate_goal_masks
        p_input['filled_mask'] = filled_mask
        p_input['obstacle_map'] = obstacle_map
        p_input['rgb_vis'] = policy_vis[1] if policy_vis[0] else None

        p_input['goal'] = goal_maps[e]
        p_input['found_goal'] = found_goal[e]
        p_input['rotate_agent'] = is_rotatation

        if visualize or print_images:
            # Writes a tiny constant into the last channel of ``full_map``
            # (in place), so cells with no other semantic evidence take the
            # last channel in the argmax.
            full_map[e, -1, :, :] = 1e-5
            p_input['sem_map_pred'] = full_map[e,
                                               4:, :, :].argmax(0).cpu().numpy()

    return planner_inputs, full_map


def update_full_and_local_map(full_map, local_map, lmb, planner_pose_inputs, full_pose, local_pose, origins, num_scenes, map_resolution, global_downscaling, local_w, local_h, full_w, full_h, device):
    """Sync every local map into the full map and re-centre the local windows.

    Per scene:
    1. copy the local map back into the full map;
    2. lift the local pose into the full-map frame;
    3. recompute the window around the agent;
    4. re-crop the local map;
    5. re-express the local pose against the new origin.

    ``full_pose`` and ``origins`` are updated in place but are not part of
    the return value.

    Returns:
        tuple: (full_map, local_map, lmb, planner_pose_inputs, local_pose).
    """
    for env in range(num_scenes):
        full_map = update_full_map_for_env(full_map, local_map, lmb, env)

        full_pose[env] = local_pose[env] + \
            torch.from_numpy(origins[env]).to(device).float()

        pose_np = full_pose[env].cpu().numpy()
        row = int(pose_np[1] * 100.0 / map_resolution)
        col = int(pose_np[0] * 100.0 / map_resolution)

        lmb[env] = get_local_map_boundaries((row, col),
                                            (local_w, local_h),
                                            (full_w, full_h),
                                            global_downscaling)
        planner_pose_inputs[env, 3:] = lmb[env]

        x1, x2, y1, y2 = lmb[env]
        origins[env] = [y1 * map_resolution / 100.0,
                        x1 * map_resolution / 100.0, 0.]

        local_map[env] = full_map[env, :, x1:x2, y1:y2]
        local_pose[env] = full_pose[env] - \
            torch.from_numpy(origins[env]).to(device).float()
    return full_map, local_map, lmb, planner_pose_inputs, local_pose


def reset_current_loction(local_pose, planner_pose_inputs, origins, local_map, num_scenes, map_resolution):
    """Refresh the agents' global poses and the current-location markers.

    ``planner_pose_inputs[:, :3]`` becomes ``local_pose + origins``.
    Channel 2 of ``local_map`` (current location) is cleared. A 3x3 block
    around each agent's local cell is then set to 1 in channels 2 and 3
    (current and past locations); channel 3 is never cleared. The block is
    clipped at the map border.

    Returns:
        tuple: (planner_pose_inputs, local_map), both modified in place.
    """
    locs = local_pose.cpu().numpy()
    planner_pose_inputs[:, :3] = locs + origins
    local_map[:, 2, :, :].fill_(0.)
    for e in range(num_scenes):
        r, c = locs[e, 1], locs[e, 0]
        loc_r = int(r * 100.0 / map_resolution)
        loc_c = int(c * 100.0 / map_resolution)

        # Clamp both slice ends at 0. A negative start counts from the end,
        # so e.g. ``-1:2`` is an empty slice and the marker silently
        # disappeared whenever the agent stood on the first row/column.
        r0, r1 = max(loc_r - 1, 0), max(loc_r + 2, 0)
        c0, c1 = max(loc_c - 1, 0), max(loc_c + 2, 0)
        local_map[e, 2:4, r0:r1, c0:c1] = 1.
    return planner_pose_inputs, local_map


def whether_seen_the_goal_and_set_goal_map(num_scenes, infos, full_map, goal_maps, found_goal):
    """Build goal maps from the goal-category channel of the full map.

    For each scene, channel ``infos[e]['goal_cat_id'] + 4`` of ``full_map``
    is inspected. If it has any non-zero cell, its positive cells are
    binarised to 1. The goal map then keeps only the largest connected
    region (``scipy.ndimage.label`` default connectivity), and
    ``found_goal[e]`` is set to 1. Scenes with an empty channel get an
    all-zero goal map.

    Args:
        num_scenes: number of scenes.
        infos: per-scene dicts with a ``'goal_cat_id'`` entry.
        full_map: full map tensor, indexed ``[scene, channel, H, W]``;
            it is not modified.
        goal_maps: ignored; fresh zero maps are always created (kept for
            interface compatibility).
        found_goal: per-scene flags, updated in place.

    Returns:
        tuple: (goal_maps, found_goal).
    """
    goal_maps = [np.zeros((full_map.shape[2], full_map.shape[3]))
                 for _ in range(num_scenes)]
    for e in range(num_scenes):
        cn = infos[e]['goal_cat_id'] + 4
        if full_map[e, cn, :, :].sum() == 0.:
            continue

        # Copy: for a CPU tensor ``.cpu().numpy()`` shares memory with
        # ``full_map``, so binarising in place used to overwrite the map.
        cat_semantic_scores = full_map[e, cn, :, :].cpu().numpy().copy()
        cat_semantic_scores[cat_semantic_scores > 0] = 1.

        labeled_array, num_features = label(cat_semantic_scores)
        if num_features > 0:
            # Pixel count per label in one pass; label 0 is background.
            region_areas = np.bincount(labeled_array.ravel())[1:]
            max_area_idx = np.argmax(region_areas) + 1
            goal_maps[e] = (labeled_array == max_area_idx).astype(float)
        else:
            goal_maps[e] = cat_semantic_scores

        found_goal[e] = 1

    return goal_maps, found_goal


def generate_candidate_goal_map(full_map, num_candidate_goals, start_channel=4):
    """Return one single-pixel centroid map per candidate channel of scene 0.

    Map ``k`` comes from channel ``start_channel + k`` of ``full_map[0]``.
    It is all zeros, except for a 1 at the rounded mean (row, col) of the
    channel's positive cells when the channel is non-empty.

    Returns:
        numpy.ndarray of shape (num_candidate_goals, H, W).
    """
    height, width = full_map.shape[2], full_map.shape[3]
    candidate_goal_map = np.zeros((num_candidate_goals, height, width))
    for k in range(num_candidate_goals):
        layer = full_map[0, start_channel + k, :, :]
        if layer.sum() == 0.:
            continue
        rows, cols = np.where(layer.cpu().numpy() > 0)
        if len(rows) == 0:
            continue
        centre_row = int(np.round(np.mean(rows)))
        centre_col = int(np.round(np.mean(cols)))
        candidate_goal_map[k] = 0
        candidate_goal_map[k][centre_row, centre_col] = 1
    return candidate_goal_map


def set_the_selected_goal_map(num_scenes, full_map):
    """Turn channel 14 of each scene's full map into a single-pixel goal map.

    Each returned map is all zeros, except when channel 14 has positive
    cells. In that case it gets a 1 at the rounded mean (row, col) of those
    cells, and the map takes the dtype of the channel's numpy conversion.

    Returns:
        list of numpy arrays of shape (H, W), one per scene.
    """
    selected_channel = 14  # channel holding the next selected goal point
    height, width = full_map.shape[2], full_map.shape[3]

    goal_maps = []
    for env in range(num_scenes):
        goal = np.zeros((height, width))
        layer = full_map[env, selected_channel, :, :]
        if layer.sum() != 0.:
            binary = layer.cpu().numpy()
            rows, cols = np.where(binary > 0)
            if len(rows) > 0:
                goal = np.zeros_like(binary)
                goal[int(np.round(np.mean(rows))),
                     int(np.round(np.mean(cols)))] = 1
        goal_maps.append(goal)
    return goal_maps


def set_random_goal_map(num_scenes, goal_maps, local_w, local_h, e=0):
    """Mark one random cell of ``goal_maps[e]`` as the goal.

    A (num_scenes, 2) standard-normal sample, scaled by ``num_scenes``, is
    squashed by a sigmoid into fractions. Scene ``e``'s pair is scaled by
    (local_w, local_h), truncated to int and capped at the last valid
    index. That cell is set to 1 in ``goal_maps[e]``, which is modified in
    place. Draws from torch's global RNG.

    Returns:
        The same ``goal_maps`` list.
    """
    noise = torch.randn(num_scenes, 2) * num_scenes
    fractions = torch.sigmoid(noise).cpu().numpy()
    frac_x, frac_y = fractions[e][0], fractions[e][1]

    # The cap guards against a sigmoid output that rounds to exactly 1.0.
    goal_x = min(int(frac_x * local_w), int(local_w - 1))
    goal_y = min(int(frac_y * local_h), int(local_h - 1))
    goal_maps[e][goal_x, goal_y] = 1
    return goal_maps


def update_semantic_map(infos, obs, local_map, local_pose, sem_map_module, num_scenes, device):
    """Run one semantic-mapping step; return the updated local map and pose.

    The scenes' ``infos[i]['sensor_pose']`` values are stacked into a float
    tensor on ``device``. That tensor is passed with ``obs``, ``local_map``
    and ``local_pose`` to ``sem_map_module``. Of its four outputs, only the
    second (map) and fourth (pose) are returned.
    """
    sensor_poses = [infos[i]['sensor_pose'] for i in range(num_scenes)]
    poses = torch.from_numpy(np.asarray(sensor_poses)).float().to(device)

    _, new_local_map, _, new_local_pose = sem_map_module(
        obs, poses, local_map, local_pose)
    return new_local_map, new_local_pose


def _grow_region_to_size(region_mask, target_pixels):
    """Dilate a boolean region to >= ``target_pixels``, then trim it randomly.

    Dilation stops early once the mask no longer grows (it fills the
    image). Trimming removes randomly chosen boundary pixels (NumPy's
    global RNG) until the size reaches ``target_pixels``, or stops when
    there are not more boundary pixels than need removing. Returns a new
    boolean mask.
    """
    grown = region_mask.copy()
    count = int(grown.sum())

    while count < target_pixels:
        dilated = ndimage.binary_dilation(grown)
        dilated_count = int(dilated.sum())
        if dilated_count == count:
            # Nothing left to grow into; without this check the loop never
            # ended when the image had fewer than target_pixels pixels.
            break
        grown, count = dilated, dilated_count

    while count > target_pixels:
        edges = grown ^ ndimage.binary_erosion(grown)
        edge_rows, edge_cols = np.nonzero(edges)
        remove_count = int(count - target_pixels)
        if len(edge_rows) <= remove_count:
            break
        drop = np.random.choice(len(edge_rows), remove_count, replace=False)
        grown[edge_rows[drop], edge_cols[drop]] = False
        count -= remove_count

    return grown


def expand_semantic_mask(obs, target_pixels=200, start_channel=4, end_channel=9):
    """Grow small regions of the selected semantic channels to ~``target_pixels``.

    Scenes and channels are processed one at a time. A channel is skipped
    unless its sum is positive and below ``target_pixels``. Otherwise every
    connected region (``scipy.ndimage.label`` default connectivity) is
    dilated to at least ``target_pixels`` pixels and randomly trimmed back
    toward that size. The union of the original mask and the grown regions
    replaces the channel in a copy of ``obs``.

    Args:
        obs: 4-D tensor (scenes, channels, H, W). According to the original
            notes it is (num_scenes, 20, 120, 160): 4 RGB-D channels followed
            by 16 semantic masks.
        target_pixels: desired pixel count per region (default 200).
        start_channel: first channel to process (inclusive).
        end_channel: last channel to process (inclusive).

    Returns:
        A new tensor with the same shape as ``obs``; ``obs`` is not modified.
    """
    output = obs.clone()

    for scene_idx in range(obs.shape[0]):
        for channel_idx in range(start_channel, end_channel + 1):
            current_mask = obs[scene_idx, channel_idx].cpu().numpy()

            num_ones = np.sum(current_mask)
            # Skip empty channels and channels already large enough.
            if num_ones == 0 or num_ones >= target_pixels:
                continue

            labeled_array, num_features = ndimage.label(current_mask)
            for region_idx in range(1, num_features + 1):
                grown = _grow_region_to_size(labeled_array == region_idx,
                                             target_pixels)
                current_mask = np.logical_or(current_mask, grown)

            output[scene_idx, channel_idx] = torch.from_numpy(
                current_mask).to(obs.device)

    return output


def whether_seen_the_goal(obs, infos, num_scenes):
    """Return, per scene, whether the goal category is present in ``obs``.

    The goal channel is ``infos[e]['goal_cat_id'] + 4``. A scene counts as
    having seen the goal when that channel's sum is non-zero.

    Returns:
        list of bool, one per scene.
    """
    return [
        bool(obs[e, infos[e]['goal_cat_id'] + 4, :, :].sum() != 0.)
        for e in range(num_scenes)
    ]


def visualize_and_save(tensor_before, tensor_after, channel_idx=8):
    """Plot one mask channel before and after expansion and print its sums.

    Channel ``channel_idx`` of scene 0 from both tensors is drawn side by
    side (grayscale, with colour bars) into
    ``visualization_results/expansion_comparison_running.png``. The
    channel's sum before and after is then printed.
    """
    save_dir = "visualization_results"
    os.makedirs(save_dir, exist_ok=True)

    panels = (('Before Expansion', tensor_before),
              ('After Expansion', tensor_after))

    plt.figure(figsize=(12, 6))
    for position, (title, tensor) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(tensor[0, channel_idx].cpu().numpy(), cmap='gray')
        plt.title(title)
        plt.colorbar()
    plt.savefig(os.path.join(save_dir, 'expansion_comparison_running.png'))
    plt.close()

    ones_before, ones_after = (torch.sum(t[0, channel_idx]).item()
                               for t in (tensor_before, tensor_after))
    print(f"Number of 1s before expansion: {ones_before}")
    print(f"Number of 1s after expansion: {ones_after}")


def extract_number(input_str: str) -> int:
    """Return the first integer inside the last ``{...}`` of ``input_str``.

    The last ``}`` is located, then the nearest ``{`` before it. The first
    run of decimal digits between them is returned as an int, so ``0`` is
    possible and a leading ``-`` is ignored. Returns -2 when there is no
    ``}``, no ``{`` before it, or no digits in between.

    >>> extract_number("pick {view: 3}")
    3
    """
    before_close, close_brace, _ = input_str.rpartition('}')
    if not close_brace:
        return -2

    open_idx = before_close.rfind('{')
    if open_idx < 0:
        return -2

    digits = re.search(r'\d+', before_close[open_idx:])
    return int(digits.group(0)) if digits else -2


def add_image_number(rgb, save_images=False):
    """Draw a 1-based index badge in the top-left corner of each image.

    Each badge is a white filled circle with a red outline and the black
    index number centred inside. Drawing is done on the BGR conversion
    returned by ``cv2.cvtColor``. The result is converted back to RGB.

    Args:
        rgb: iterable of RGB numpy images (H x W x C; the original notes
            mention (480, 640, 4)).
        save_images: if True, also write each badged image (BGR) to
            ``numbered_images/image_<n>.jpg``.

    Returns:
        list of RGB numpy arrays with the badge drawn.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale, thickness = 1.2, 2
    radius, margin, border = 25, 15, 3
    # Keep the circle fully inside the image with a small margin.
    centre = (margin + radius, margin + radius)

    if save_images:
        os.makedirs("numbered_images", exist_ok=True)

    numbered = []
    for number, image in enumerate(rgb, start=1):
        canvas = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

        cv2.circle(canvas, centre, radius, (255, 255, 255), -1)  # white fill
        cv2.circle(canvas, centre, radius, (0, 0, 255), border)  # red ring (BGR)

        text = str(number)
        (text_w, text_h), _ = cv2.getTextSize(text, font, font_scale, thickness)
        origin = (centre[0] - text_w // 2, centre[1] + text_h // 2)
        cv2.putText(canvas, text, origin, font, font_scale,
                    (0, 0, 0), thickness, cv2.LINE_AA)

        numbered.append(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))

        if save_images:
            cv2.imwrite(f"numbered_images/image_{number}.jpg", canvas)

    return numbered


def select_view_at_beginning(vlm, history_rgb, infos, logger, save_images=True) -> int:
    """Ask the VLM which of the initial rotation views to explore.

    The views in ``history_rgb`` are badged 1..N via ``add_image_number``.
    They are sent with ``infos[0]['goal_name']`` to
    ``vlm.ask_multiple_images``, and the response is logged at INFO level.

    Returns:
        int: the number found by ``extract_number`` in the response
        (-2 when none is found).
    """
    numeric_rotate_views = add_image_number(
        history_rgb, save_images=save_images)
    view_response = vlm.ask_multiple_images(
        numeric_rotate_views, infos[0]['goal_name'])

    # Plain f-string: the previous ``" ".join(f"...")`` iterated over the
    # string's characters and put a space between every one of them.
    log = f"[Ep {infos[0]['episode_no']:03d} | Step {infos[0]['time']+1:03d}]\n"
    log += view_response
    log += "\n"
    log += "--" * 40
    logger.info(log)

    return extract_number(view_response)


def exp_goal_based_cur_view(rgbd, vlm, trav, infos, logger) -> tuple:
    """Let the VLM choose an exploration goal in the current view.

    Steps:
    1. Run ``trav.run_current_view`` on the RGB (channels 0-2) and depth
       (channel 3) of ``rgbd``.
    2. Fetch the image annotated with candidate goals.
    3. Ask ``vlm.ask_single_image`` (with ``infos[0]['goal_name']``) which
       candidate to take, and parse the answer with ``extract_number``.
    4. Highlight and save the chosen candidate.

    Returns:
        tuple: ``(goal_mask, action_highlighted_np)`` from
        ``trav.get_goal_mask_of_select_point`` and
        ``trav.highlight_selected_goal``, or ``(None, None)`` when no
        traversable area is found or no annotated image is produced.
        (The previous ``-> torch.Tensor`` annotation did not match this.)
    """
    def _log_header():
        # Plain f-string: ``" ".join(f"...")`` used to insert a space
        # between every character of the header.
        return f"[Ep {infos[0]['episode_no']:03d} | Step {infos[0]['time']+1:03d}]\n"

    footer = "\n" + "--" * 40

    success = trav.run_current_view(
        rgb=rgbd[:, :, :3], depth=rgbd[:, :, 3])
    if not success:
        logger.warning(
            _log_header()
            + "WARNING: No traversable area found in current view. Skipping VLM processing."
            + footer)
        return None, None

    rgb_annotated_np = trav.get_numpy_of_annotated_candidate_exp_goals()
    if rgb_annotated_np is None:
        logger.error(
            _log_header()
            + "ERROR: Failed to generate annotated image despite successful traversable area detection."
            + footer)
        return None, None

    action_response = vlm.ask_single_image(
        rgb_annotated_np, infos[0]['goal_name'])
    logger.info(_log_header() + action_response + footer)

    # extract_number returns -2 when the response contains no number.
    action = extract_number(action_response)
    action_highlighted_np = trav.highlight_selected_goal(
        selected_goal=action)
    save_image_batch([action_highlighted_np])
    goal_mask = trav.get_goal_mask_of_select_point(action)
    return goal_mask, action_highlighted_np


def set_base(args):
    """Allocate the map and pose buffers used by the navigation loop.

    Side effect: sets ``args.device``. Full-map channels, per the original
    notes:
    - 0: obstacle map
    - 1: explored area
    - 2: current agent location
    - 3: past agent locations
    - 4 onward: ``args.num_sem_categories`` semantic categories

    ``planner_pose_inputs`` has 7 columns: 3 for the continuous global
    pose, then 4 for the local-map boundaries.

    Returns:
        tuple: (full_w, full_h, local_w, local_h, num_scenes, num_episodes,
        device, finished, full_map, local_map, full_pose, local_pose,
        origins, lmb, planner_pose_inputs).
    """
    num_scenes = args.num_processes
    num_episodes = int(args.num_eval_episodes)
    device = torch.device("cuda" if args.cuda else "cpu")
    args.device = device

    finished = np.zeros((args.num_processes))
    num_channels = args.num_sem_categories + 4

    map_size = args.map_size_cm // args.map_resolution
    full_w = full_h = map_size
    local_w = int(full_w / args.global_downscaling)
    local_h = int(full_h / args.global_downscaling)

    full_map = torch.zeros(num_scenes, num_channels, full_w, full_h).float().to(device)
    local_map = torch.zeros(num_scenes, num_channels, local_w, local_h).float().to(device)

    full_pose = torch.zeros(num_scenes, 3).float().to(device)
    local_pose = torch.zeros(num_scenes, 3).float().to(device)

    origins = np.zeros((num_scenes, 3))               # local-map origins
    lmb = np.zeros((num_scenes, 4)).astype(int)       # local map boundaries
    planner_pose_inputs = np.zeros((num_scenes, 7))

    return (full_w, full_h, local_w, local_h, num_scenes, num_episodes,
            device, finished, full_map, local_map, full_pose, local_pose,
            origins, lmb, planner_pose_inputs)


def set_more(args, num_scenes, full_w, full_h, num_episodes):
    """Build the semantic-mapping module and the per-run bookkeeping state.

    Side effects:
    - sets ``args.device``;
    - disables autograd globally via ``torch.set_grad_enabled(False)``.

    Returns:
        tuple: (rotation_history, rotatation_over_at_beginning,
        new_goal_required, dones, policy_vis, dones, sem_map_module,
        found_goal, goal_maps, spl_per_category, success_per_category,
        start). ``dones`` appears twice; both entries are the same list.
    """
    device = torch.device("cuda" if args.cuda else "cpu")
    args.device = device

    # Semantic_Mapping returns (fp_map_pred, map_pred, pose_pred, current_poses).
    sem_map_module = Semantic_Mapping(args).to(device)
    sem_map_module.eval()

    found_goal = [0] * num_scenes
    goal_maps = [np.zeros((full_w, full_h)) for _ in range(num_scenes)]

    torch.set_grad_enabled(False)

    spl_per_category = defaultdict(list)
    success_per_category = defaultdict(list)
    start = time.time()

    policy_vis = [False, []]
    dones = [False] * num_scenes
    rotatation_over_at_beginning = [False] * num_episodes
    new_goal_required = [False] * num_scenes
    # 12 independent slots (presumably one per initial rotation view), each
    # holding the RGBD frame, local pose, observation and sensor pose.
    rotation_history = [dict.fromkeys(('rgbd', 'local_pose', 'obs', 'sensor_pose'))
                        for _ in range(12)]

    return (rotation_history, rotatation_over_at_beginning, new_goal_required,
            dones, policy_vis, dones,
            sem_map_module, found_goal, goal_maps,
            spl_per_category, success_per_category, start)


def set_doneInfo(args):
    """Create the per-process result buffers used during evaluation.

    When ``args.eval`` is truthy, returns one bounded deque (maxlen
    ``args.num_eval_episodes``) per process for each of success, SPL and
    distance-to-goal. Otherwise those three are ``None``; previously that
    path raised UnboundLocalError.

    Returns:
        tuple: (episode_success, episode_spl, episode_dist,
        spl_per_category, success_per_category); the last two are empty
        ``defaultdict(list)`` objects.
    """
    if args.eval:
        def _per_process_buffers():
            return [deque(maxlen=args.num_eval_episodes)
                    for _ in range(args.num_processes)]

        episode_success = _per_process_buffers()
        episode_spl = _per_process_buffers()
        episode_dist = _per_process_buffers()
    else:
        episode_success = episode_spl = episode_dist = None

    spl_per_category = defaultdict(list)
    success_per_category = defaultdict(list)
    return episode_success, episode_spl, episode_dist, spl_per_category, success_per_category


def set_logger(args):
    """Create the dump/log directories and return the 'navigation' logger.

    Creates ``<dump_location>/logs/<exp_name>/`` and
    ``<dump_location>/dump/<exp_name>/`` if they are missing. The
    'navigation' logger logs at INFO level to ``navigation.log`` in the
    logs directory.

    Calling this again for the same log file does not attach a second
    handler (which would duplicate every log line). The file is written as
    UTF-8, so non-ASCII text (e.g. VLM responses) is always encodable.

    Returns:
        logging.Logger: the 'navigation' logger.
    """
    dump_dir = Path("{}/dump/{}/".format(args.dump_location, args.exp_name))
    log_dir = Path("{}/logs/{}/".format(args.dump_location, args.exp_name))

    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(dump_dir, exist_ok=True)

    logger = logging.getLogger('navigation')
    logger.setLevel(logging.INFO)

    # FileHandler stores the absolute path in ``baseFilename``.
    log_path = os.path.abspath(log_dir / 'navigation.log')
    already_attached = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == log_path
        for h in logger.handlers)

    if not already_attached:
        handler = logging.FileHandler(log_path, encoding='utf-8')
        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger


def log_interval(logger, infos, args, num_scenes, start, episode_success, episode_spl, episode_dist):
    """Print and log progress: elapsed time, timesteps, FPS and running metrics.

    The message is built from:
    - elapsed time (days, then h/m/s);
    - the total timesteps ``infos[0]['time'] * num_scenes``;
    - FPS.

    When any SPL values have been recorded, it also gets the mean success,
    SPL and distance-to-goal plus the episode count. The message is only
    printed and logged when ``args.eval`` is truthy.
    """
    step = infos[0]['time']
    end = time.time()
    elapsed = time.gmtime(end - start)
    log = " ".join([
        "Time: {0:0=2d}d".format(elapsed.tm_mday - 1),
        "{},".format(time.strftime("%Hh %Mm %Ss", elapsed)),
        "num timesteps {},".format(step * num_scenes),
        "FPS {},".format(int(step * num_scenes / (end - start))),
    ])

    if not args.eval:
        return

    processes = range(args.num_processes)
    total_success = [v for e in processes for v in episode_success[e]]
    total_dist = [v for e in processes for v in episode_dist[e]]
    total_spl = [v for e in processes for v in episode_spl[e]]

    if total_spl:
        log += " ObjectNav succ/spl/dtg:"
        log += " {:.3f}/{:.3f}/{:.3f}({:.0f}),".format(
            np.mean(total_success),
            np.mean(total_spl),
            np.mean(total_dist),
            len(total_spl))

    print(log)
    logger.info(log + "\n" + "--" * 40)


def obtain_final_res(args, logger, episode_success, episode_spl, episode_dist, spl_per_category, success_per_category):
    """Report final evaluation metrics and dump per-category results (eval only).

    Steps:
    1. Print and log the overall mean success / SPL / distance-to-goal
       (only if SPL values exist).
    2. Print and log the per-category average success and SPL.
    3. Write ``<split>_spl_per_cat_pred_thr.json`` and
       ``<split>_success_per_cat_pred_thr.json`` into
       ``<dump_location>/dump/<exp_name>/``. That directory must already
       exist.

    Does nothing when ``args.eval`` is falsy.
    """
    if not args.eval:
        return

    print("Dumping eval details...")
    processes = range(args.num_processes)
    total_success = [v for e in processes for v in episode_success[e]]
    total_dist = [v for e in processes for v in episode_dist[e]]
    total_spl = [v for e in processes for v in episode_spl[e]]

    summary = ""
    if total_spl:
        summary = "Final ObjectNav succ/spl/dtg:"
        summary += " {:.3f}/{:.3f}/{:.3f}({:.0f}),".format(
            np.mean(total_success),
            np.mean(total_spl),
            np.mean(total_dist),
            len(total_spl))
    print(summary)
    logger.info(summary + "\n" + "--" * 40)

    per_category = "Success | SPL per category\n" + "".join(
        "{}: {} | {}\n".format(
            key,
            sum(success_per_category[key]) / len(success_per_category[key]),
            sum(spl_per_category[key]) / len(spl_per_category[key]))
        for key in success_per_category)
    print(per_category)
    logger.info(per_category + "\n" + "--" * 40)

    dump_dir = Path("{}/dump/{}/".format(args.dump_location, args.exp_name))
    with open('{}/{}_spl_per_cat_pred_thr.json'.format(
            dump_dir, args.split), 'w') as f:
        json.dump(spl_per_category, f)
    with open('{}/{}_success_per_cat_pred_thr.json'.format(
            dump_dir, args.split), 'w') as f:
        json.dump(success_per_category, f)


def replace_batch_slices(tmp_goals_masks, obs, start_index=6):
    """Return a copy of ``obs`` with ``n`` consecutive channels replaced by masks.

    ``tmp_goals_masks`` (numpy, shape (n, H, W)) is written into channels
    ``start_index`` .. ``start_index + n - 1`` of every batch element of
    ``obs`` (tensor, shape (B, C, H, W), e.g. (1, 20, 120, 160)). The masks
    are cast to ``obs``'s dtype and device; ``obs`` itself is not modified.

    Previously the slice was applied to dim 0 (batch), although the checks
    validate against dim 1 (channels). For the documented (1, C, H, W)
    input that selected an empty slice, so the assignment either raised or
    silently did nothing.

    Raises:
        ValueError: if the spatial sizes differ, or ``obs`` has too few
            channels for the replacement.
    """
    batch1 = tmp_goals_masks.shape[0]   # number of masks, e.g. (num, 120, 160)
    batch2 = obs.shape[1]               # number of channels, e.g. (1, 20, 120, 160)
    h1, w1 = tmp_goals_masks.shape[1:]
    h2, w2 = obs.shape[2:]

    if h1 != h2 or w1 != w2:
        raise ValueError(f"空间维度不匹配: n1({h1}x{w1}) vs n2({h2}x{w2})")

    if start_index + batch1 > batch2:
        required = start_index + batch1
        raise ValueError(
            f"n2的batch数({batch2})不足，需要至少{required}个batch。"
            f"当前batch1大小为{batch1}，最小替换索引为{start_index}"
        )

    result = obs.clone()
    masks = torch.from_numpy(tmp_goals_masks).to(
        device=result.device, dtype=result.dtype)
    # Broadcast the (n, H, W) masks over the batch dimension.
    result[:, start_index:start_index + batch1] = masks
    return result


def generate_centroid_masks(full_map, min_index, max_index):
    """
    Mark the centroid of each selected channel's nonzero cells with a single 1.

    Args:
        full_map: numpy array indexed as full_map[0, channel, row, col]
            (e.g. shape (1, 20, height, width)).
        min_index: first channel to process (inclusive).
        max_index: last channel to process (inclusive).

    Returns:
        uint8 array of shape (max_index - min_index + 1, height, width).
        For a channel with any nonzero value, the pixel at the rounded mean
        (row, col) of its nonzero cells is 1. An all-zero channel yields an
        all-zero mask.

    Raises:
        ValueError: if the indices fall outside the channel axis, or
            min_index > max_index.
    """
    if min_index < 0 or max_index >= full_map.shape[1]:
        raise ValueError(f"索引超出范围. 有效通道范围: [0, {full_map.shape[1]-1}]")
    if min_index > max_index:
        raise ValueError("min_index不能大于max_index")

    selected = full_map[0, min_index:max_index + 1]
    n_layers, n_rows, n_cols = selected.shape
    out = np.zeros((n_layers, n_rows, n_cols), dtype=np.uint8)

    for layer_idx, layer in enumerate(selected):
        rows, cols = np.nonzero(layer)
        if rows.size == 0:
            # Empty channel: nothing to mark.
            continue
        centre_row = int(np.round(np.mean(rows)))
        centre_col = int(np.round(np.mean(cols)))
        out[layer_idx, centre_row, centre_col] = 1

    return out


def get_agent_position_mask(planner_pose_inputs, full_map_shape, args, env_idx=0):
    """
    Rasterise the agent's planar position into a one-hot, map-sized mask.

    Args:
        planner_pose_inputs: indexable as [env_idx, 0] (x) and [env_idx, 1]
            (y). Values are converted to pixels via ``v * 100.0 /
            args.map_resolution`` (presumably metres and cm/pixel; confirm).
        full_map_shape: shape of the full map; entries 2 and 3 are
            (height, width).
        args: object exposing ``map_resolution``.
        env_idx: which environment row of planner_pose_inputs to read.

    Returns:
        uint8 numpy array of shape (height, width). It holds a single 1 at
        (int(y * 100 / res), int(x * 100 / res)), or is all zeros when that
        cell falls outside the map.
    """
    n_rows, n_cols = full_map_shape[2], full_map_shape[3]

    pose_x = planner_pose_inputs[env_idx, 0]
    pose_y = planner_pose_inputs[env_idx, 1]
    row = int(pose_y * 100.0 / args.map_resolution)
    col = int(pose_x * 100.0 / args.map_resolution)

    mask = np.zeros((n_rows, n_cols), dtype=np.uint8)
    on_map = 0 <= row < n_rows and 0 <= col < n_cols
    if on_map:
        mask[row, col] = 1
    return mask


def expand_masks(original_masks, target_area=10000):
    """
    Grow each channel's marked pixel into a filled disc.

    Args:
        original_masks: array of shape (num, height, width). Every channel
            must contain at least one pixel equal to 1; the first such pixel
            in row-major order is used as the disc centre.
        target_area: desired disc area in pixels; the radius is
            sqrt(target_area / pi).

    Returns:
        None when ``num == 0``. Otherwise, an array with the same shape and
        dtype as ``original_masks`` where each channel is its disc (1 inside,
        0 outside). The radius is reduced to the centre's distance from the
        nearest image border when the full disc would not fit.

    Raises:
        ValueError: if target_area <= 1, or a channel has no pixel equal to 1.
    """
    if original_masks.shape[0] == 0:
        return None

    # Work on a copy so the caller's masks stay untouched (keeps the dtype).
    out = original_masks.copy()
    num, height, width = out.shape

    if target_area <= 1:
        raise ValueError("目标面积值必须大于1")

    def _first_one(idx):
        # First pixel equal to 1 in row-major order.
        hits = np.argwhere(out[idx] == 1)
        if len(hits) == 0:
            raise ValueError(f"第 {idx} 个通道中没有找到值为1的位置")
        return hits[0]

    # Locate every centre before drawing any disc.
    centres = [_first_one(idx) for idx in range(num)]

    wanted_radius = sqrt(target_area / pi)
    rows, cols = np.ogrid[:height, :width]

    for idx, (cy, cx) in enumerate(centres):
        # Largest radius that keeps the disc inside the image.
        border_room = min(cy, height - 1 - cy, cx, width - 1 - cx)
        radius = wanted_radius if border_room >= wanted_radius else border_room
        distance = np.sqrt((cols - cx) ** 2 + (rows - cy) ** 2)
        out[idx] = (distance <= radius).astype(np.uint8)

    return out


# def visualize_obs_exp_agent_masks(obstacle_mask, exp_mask, agent_position_mask, output_path="./visualize_obs_exp_agent_masks.png", dpi=300):
#     """
#     可视化三个二值掩码

#     参数:
#         mask1: 第一个掩码 (height, width), 1=浅灰, 0=白
#         mask2: 第二个掩码 (height, width), 1=浅灰, 0=白
#         mask3: 第三个掩码 (height, width), 1=红, 0=白
#         output_path: 保存路径 (None则不保存)
#         dpi: 图像分辨率

#     返回:
#         matplotlib Figure对象
#     """
#     # 检查输入形状是否一致
#     if not (mask1.shape == mask2.shape == mask3.shape):
#         raise ValueError("所有掩码必须具有相同的形状")

#     height, width = mask1.shape
#     visualized = np.ones((height, width, 3))  # 白色背景 (RGB)

#     # 定义颜色
#     light_gray = [0.8, 0.8, 0.8]  # 浅灰色
#     red = [1.0, 0.0, 0.0]        # 红色

#     # 处理第一个掩码 (浅灰)
#     visualized[mask1 == 1] = light_gray

#     # 处理第二个掩码 (浅灰)
#     visualized[mask2 == 1] = light_gray

#     # 处理第三个掩码 (红色)
#     visualized[mask3 == 1] = red

#     # 创建图像
#     fig = plt.figure(figsize=(10, 10))
#     plt.imshow(visualized)
#     plt.axis('off')

#     # 保存图像
#     if output_path is not None:
#         plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=dpi)
#         print(f"可视化结果已保存到 {output_path}")

#     return fig


# def get_agent_position_mask(planner_pose_inputs, full_map, args, env_idx=0):
#     """
#     获取agent在full_map中位置的二值掩码

#     Args:
#         planner_pose_inputs: 规划器姿态输入 (num_scenes, 7)
#         full_map: 全局地图 (num_scenes, channels, height, width)
#         args: 参数对象，包含map_resolution
#         env_idx: 环境索引，默认为0

#     Returns:
#         numpy.ndarray: 与full_map相同大小的二值掩码 (height, width)，agent位置为1，其他为0
#     """
#     # 获取agent全局位置坐标
#     global_x = planner_pose_inputs[env_idx, 0]
#     global_y = planner_pose_inputs[env_idx, 1]

#     # 获取地图尺寸
#     map_height, map_width = full_map.shape[2], full_map.shape[3]  # 通常是480x480

#     # 转换为地图像素坐标
#     pixel_row = int(global_y * 100.0 / args.map_resolution)
#     pixel_col = int(global_x * 100.0 / args.map_resolution)

#     # 创建零掩码
#     agent_mask = np.zeros((map_height, map_width), dtype=np.uint8)

#     # 边界检查并设置agent位置
#     if 0 <= pixel_row < map_height and 0 <= pixel_col < map_width:
#         agent_mask[pixel_row, pixel_col] = 1

#     return agent_mask


# def fill_agent_position_mask(mask1, mask2, agent_position_mask):
#     """合并掩码并填充智能体周围区域"""
#     # 1. 合并两个掩码
#     combined_mask = np.logical_or(mask1, mask2).astype(np.uint8)

#     # 2. 定位智能体位置
#     agent_pos = np.argwhere(agent_position_mask == 1)
#     if len(agent_pos) == 0:
#         raise ValueError("未在agent_position_mask中找到值为1的位置")
#     y, x = agent_pos[0]

#     # 3. 验证该位置为0（未被占用）
#     if combined_mask[y, x] != 0:
#         raise ValueError("智能体位置在combined_mask中已被占用")

#     # 4. 创建填充掩码（确保不修改原图）
#     filled_mask = combined_mask.copy()

#     # 5. 填充连通区域
#     # 创建临时图像（添加边界防止轮廓越界）
#     temp_mask = np.zeros(
#         (filled_mask.shape[0]+2, filled_mask.shape[1]+2), dtype=np.uint8)
#     temp_mask[1:-1, 1:-1] = filled_mask

#     # 执行泛洪填充（从智能体位置开始填充0区域）
#     cv2.floodFill(
#         image=temp_mask,
#         mask=None,
#         seedPoint=(x+1, y+1),  # 调整边界偏移
#         newVal=1,
#         loDiff=0,
#         upDiff=0,
#         flags=4 | (1 << 8)  # 4连通+固定范围填充
#     )

#     # 移除临时边界
#     filled_mask = temp_mask[1:-1, 1:-1]

#     return combined_mask, filled_mask


def fill_agent_position_mask(mask1, mask2, agent_position_mask):
    """
    Union two masks, then also mark the free region the agent stands in.

    ``mask1`` and ``mask2`` are OR-ed into a uint8 ``combined_mask``. The
    4-connected component of zero cells that contains the agent's cell is
    then set to 1 as well.

    Args:
        mask1: binary mask of shape (height, width); numpy array or torch
            tensor (tensors are moved to the CPU).
        mask2: binary mask of shape (height, width); numpy array or tensor.
        agent_position_mask: (height, width) mask. The first cell equal to
            1 (row-major order) is taken as the agent position.

    Returns:
        uint8 numpy array of shape (height, width).

    Raises:
        ValueError: if agent_position_mask contains no 1, or the agent cell
            is already set in the union of mask1 and mask2.
    """
    # 1. Bring every input to the CPU as a numpy array.
    if torch.is_tensor(mask1):
        mask1 = mask1.cpu().numpy()
    if torch.is_tensor(mask2):
        mask2 = mask2.cpu().numpy()
    if torch.is_tensor(agent_position_mask):
        agent_position_mask = agent_position_mask.cpu().numpy()

    # 2. Union of the two masks.
    combined_mask = np.logical_or(mask1, mask2).astype(np.uint8)

    # 3. Agent position (first 1 in row-major order).
    agent_pos = np.argwhere(agent_position_mask == 1)
    if len(agent_pos) == 0:
        raise ValueError("agent_position_mask中没有找到1")
    agent_y, agent_x = agent_pos[0]

    if combined_mask[agent_y, agent_x] != 0:
        raise ValueError("智能体位置在combined_mask中不为0")

    # 4. Label the free cells. ndimage's default 2-D structure is the
    # 4-neighbour cross, which matches an up/down/left/right flood fill.
    # The labelling runs in C instead of a per-pixel Python loop.
    free_labels, _ = ndimage.label(combined_mask == 0)
    agent_region = free_labels == free_labels[agent_y, agent_x]

    # 5. Mark the agent's free region in the combined mask.
    combined_mask[agent_region] = 1

    return combined_mask


def visualize_masks(original, filled, output_path="./tmp/visualize_filled_masks.png"):
    """
    Save a side-by-side grayscale plot of two (H, W) masks.

    Args:
        original: mask shown on the left; numpy array or torch tensor
            (tensors are moved to the CPU).
        filled: mask shown on the right; numpy array or torch tensor.
        output_path: PNG destination. Its parent directory is created when
            missing. Defaults to the previously hard-coded location.

    NOTE(review): the panel titles are Chinese and may render as empty boxes
    unless matplotlib is configured with a CJK-capable font.
    """
    def _as_numpy(mask):
        # Accept both tensors (any device) and array-likes.
        return mask.cpu().numpy() if torch.is_tensor(mask) else np.asarray(mask)

    original = _as_numpy(original)
    filled = _as_numpy(filled)

    # savefig does not create directories; make sure the target exists.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    plt.figure(figsize=(12, 6))

    plt.subplot(121)
    plt.imshow(original, cmap='gray')
    plt.title("原始合并掩码")
    plt.axis('off')

    plt.subplot(122)
    plt.imshow(filled, cmap='gray')
    plt.title("填充后掩码")
    plt.axis('off')

    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()


def filter_candidate_goals(filled_mask, candidate_goals, obstacle_mask, exp_mask, agent_position_mask, k, threshold, dis_thresh, obstacle_radius=10, obstacle_threshold=0.3, trajectory_mask=None, trajectory_radius=10):
    """
    Keep only the candidate goal channels that pass every geometric filter.

    For each channel, the goal is its first pixel equal to 1 (row-major
    order); channels without such a pixel are dropped. A goal is kept when:
      1. it is not an obstacle (obstacle_mask != 1) and is explored
         (exp_mask != 0);
      2. the obstacle fraction within ``obstacle_radius`` is
         <= ``obstacle_threshold``;
      3. the fraction of zero cells of ``filled_mask`` within ``k`` is
         > ``threshold``;
      4. its Euclidean distance to the agent is >= ``dis_thresh``;
      5. when both ``trajectory_mask`` and ``trajectory_radius`` are given,
         no trajectory cell lies within ``trajectory_radius``.

    Every (H, W) mask may be a numpy array or a torch tensor (tensors are
    moved to the CPU); all of them must share the same shape.

    :param filled_mask: binary mask (H, W) whose zero cells count as free space
    :param candidate_goals: candidate goals (N, H, W)
    :param obstacle_mask: obstacle mask (H, W)
    :param exp_mask: explored-area mask (H, W)
    :param agent_position_mask: agent mask (H, W); the first 1 is the agent
    :param k: radius (pixels) for the zero-fraction check
    :param threshold: minimum zero fraction (0.0-1.0), strictly exceeded
    :param dis_thresh: minimum distance (pixels) from the agent
    :param obstacle_radius: radius (pixels) for the obstacle check, default 10
    :param obstacle_threshold: maximum obstacle fraction (0.0-1.0), default 0.3
    :param trajectory_mask: visited-cell mask (H, W), default None (no check)
    :param trajectory_radius: radius (pixels) for the trajectory check, default 10
    :return: kept channels stacked as (n_kept, H, W); (0, H, W) if none survive
    :raises ValueError: if agent_position_mask contains no 1
    """
    def _to_numpy(arr):
        return arr.cpu().numpy() if torch.is_tensor(arr) else np.asarray(arr)

    agent_np = _to_numpy(agent_position_mask)
    robot_positions = np.argwhere(agent_np == 1)
    if len(robot_positions) == 0:
        raise ValueError("agent_position_mask中没有找到1")
    # A single agent is assumed; take the first marked cell.
    robot_y, robot_x = robot_positions[0]

    filled_np = _to_numpy(filled_mask)
    obstacle_np = _to_numpy(obstacle_mask)
    exp_np = _to_numpy(exp_mask)
    traj_np = None if trajectory_mask is None else _to_numpy(trajectory_mask)

    # Loop-invariant work: the coordinate grid and the free-cell mask.
    grid_rows, grid_cols = np.indices(filled_np.shape)
    free_cells = filled_np == 0

    kept = []
    for channel in candidate_goals:
        hits = np.argwhere(channel == 1)
        if len(hits) == 0:
            continue
        y, x = hits[0]

        # Goal must be explored and not an obstacle.
        if obstacle_np[y, x] == 1 or exp_np[y, x] == 0:
            continue

        # One squared-distance map serves every circular check below.
        sq_dist = (grid_rows - y) ** 2 + (grid_cols - x) ** 2

        # Reject goals surrounded by too many obstacles.
        obstacle_circle = sq_dist <= obstacle_radius ** 2
        obstacle_area = obstacle_circle.sum()
        obstacle_count = np.logical_and(obstacle_np, obstacle_circle).sum()
        obstacle_ratio = obstacle_count / obstacle_area if obstacle_area > 0 else 0
        if obstacle_ratio > obstacle_threshold:
            continue

        # Require enough free (zero) cells around the goal.
        goal_circle = sq_dist <= k ** 2
        circle_area = goal_circle.sum()
        zero_count = np.logical_and(free_cells, goal_circle).sum()
        zero_ratio = zero_count / circle_area if circle_area > 0 else 0
        if not zero_ratio > threshold:
            continue

        # Require a minimum distance from the agent.
        robot_distance = np.sqrt((y - robot_y) ** 2 + (x - robot_x) ** 2)
        if not robot_distance >= dis_thresh:
            continue

        # Avoid goals next to places the agent has already visited.
        if traj_np is not None and trajectory_radius is not None:
            trajectory_circle = sq_dist <= trajectory_radius ** 2
            if np.any(np.logical_and(traj_np, trajectory_circle)):
                continue

        kept.append(channel)

    return np.stack(kept) if kept else np.empty((0, *filled_np.shape))


def sort_candidate_goals_by_distance(candidate_goals, agent_position):
    """
    Reorder candidate-goal channels by Euclidean distance to the agent, nearest first.

    Ties keep their original relative order (stable sort).

    :param candidate_goals: array (num_candidate, height, width); the first
        pixel equal to 1 in each channel is that channel's goal
    :param agent_position: array (height, width); its first 1 is the agent
    :return: new array, same shape and dtype, with channels reordered
    """
    agent_row, agent_col = np.argwhere(agent_position == 1)[0]

    def _distance(goal_channel):
        goal_row, goal_col = np.argwhere(goal_channel == 1)[0]
        return np.sqrt((goal_row - agent_row) ** 2 + (goal_col - agent_col) ** 2)

    keys = [_distance(candidate_goals[idx]) for idx in range(candidate_goals.shape[0])]
    order = np.argsort(keys, kind='stable')

    ordered = np.zeros_like(candidate_goals)
    ordered[:] = candidate_goals[order]
    return ordered


# def get_final_candidate_goal_map(full_map, candidate_goal_map_list_or_numpy, planner_pose_inputs, args):
#     obstacle_mask = full_map[0, 0, :, :] > 0
#     exp_mask = full_map[0, 1, :, :] > 0
#     agent_position_mask = get_agent_position_mask(
#         planner_pose_inputs, full_map.shape, args)
#     #
#     combined_mask = np.logical_or(
#         obstacle_mask.cpu().numpy(), exp_mask.cpu().numpy()).astype(np.uint8)
#     # filled_mask = fill_agent_position_mask(
#     #     obstacle_mask, exp_mask, agent_position_mask)
#     if type(candidate_goal_map_list_or_numpy) == list:
#         candidate_goal_maps = np.concatenate(
#             candidate_goal_map_list_or_numpy, axis=0)
#     else:
#         candidate_goal_maps = candidate_goal_map_list_or_numpy

#     filtered_candidate_goal_map = filter_candidate_goals(
#         combined_mask, candidate_goal_maps, obstacle_mask, exp_mask, agent_position_mask, k=8, threshold=0.2, obstacle_radius=8, dis_thresh=5, obstacle_threshold=0.5)
#     if filtered_candidate_goal_map.shape[0] == 0:
#         return combined_mask, obstacle_mask, None
#     sorted_candidate_goal_map = sort_candidate_goals_by_distance(
#         filtered_candidate_goal_map, agent_position_mask)
#     return combined_mask, obstacle_mask, sorted_candidate_goal_map


def update_final_condidate_goal_map(full_map, candidate_goal_map_list_or_numpy, NextGoalIterator, planner_pose_inputs, args, trajectory_mask=None):
    """
    Build the filtered, nearest-first candidate goal stack for the agent.

    When ``candidate_goal_map_list_or_numpy`` is a list, its arrays are
    concatenated along axis 0 and prefixed with the goals ``NextGoalIterator``
    has not yet handed out. An empty list (or None) means "no new candidates",
    so only those leftovers are considered. Any other array is used as-is;
    the iterator's leftovers are not merged in that case.

    The resulting stack is passed through ``filter_candidate_goals`` and
    ``sort_candidate_goals_by_distance``. The filter uses obstacle =
    full_map[0, 0] > 0, explored = full_map[0, 1] > 0, and their union as
    ``filled_mask``.

    Returns:
        numpy array of shape (n, H, W) sorted nearest-first; n may be 0.
    """
    map_h, map_w = full_map.shape[2], full_map.shape[3]

    new_goals = candidate_goal_map_list_or_numpy
    if new_goals is None:
        new_goals = []

    if isinstance(new_goals, list):
        leftovers = None
        if NextGoalIterator is not None:
            leftovers = NextGoalIterator.get_remaining_goals()
        parts = ([leftovers] if leftovers is not None else []) + list(new_goals)
        if parts:
            all_new_candidate_goal_map = np.concatenate(parts, axis=0)
        else:
            # Nothing new and nothing left over: an empty stack keeps the
            # filter/sort below working instead of failing on a bare list.
            all_new_candidate_goal_map = np.empty((0, map_h, map_w))
    else:
        all_new_candidate_goal_map = new_goals

    agent_position_mask = get_agent_position_mask(
        planner_pose_inputs, full_map.shape, args)
    obstacle_mask = full_map[0, 0, :, :] > 0
    exp_mask = full_map[0, 1, :, :] > 0
    combined_mask = np.logical_or(
        obstacle_mask.cpu().numpy(), exp_mask.cpu().numpy()).astype(np.uint8)
    filtered_candidate_goal_map = filter_candidate_goals(
        combined_mask, all_new_candidate_goal_map, obstacle_mask, exp_mask, agent_position_mask, k=8, threshold=0.2, obstacle_radius=8, dis_thresh=5, obstacle_threshold=0.5, trajectory_mask=trajectory_mask)
    sorted_candidate_goal_map = sort_candidate_goals_by_distance(
        filtered_candidate_goal_map, agent_position_mask)
    return sorted_candidate_goal_map


def visualize_goals(filled_mask, candidate_goals, obstacle_map, k, threshold, output_path="result.png"):
    """
    Render the map, the candidate goals and their free-space percentage.

    Each candidate is drawn as a red dot and labelled with the percentage of
    zero-valued ``filled_mask`` cells within radius ``k`` of it. The label is
    green when that percentage is below ``threshold`` and red otherwise, so
    ``threshold`` is on a 0-100 scale here.

    :param filled_mask: binary mask (H, W); 1-cells are drawn gray
    :param candidate_goals: candidate goals (N, H, W), one pixel equal to 1 each
    :param obstacle_map: obstacle mask (H, W); 1-cells inside filled_mask are dark gray
    :param k: radius (pixels) used for the percentage
    :param threshold: percentage threshold (0-100) for the label colour
    :param output_path: image file written with cv2.imwrite
    :return: the rendered (H, W, 3) uint8 image in OpenCV's BGR channel order
    """
    def _as_numpy(arr):
        return arr.cpu().numpy() if torch.is_tensor(arr) else np.asarray(arr)

    filled_mask = _as_numpy(filled_mask)
    obstacle_map = _as_numpy(obstacle_map)
    candidate_goals = _as_numpy(candidate_goals)

    # 1. White canvas.
    h, w = filled_mask.shape
    base_img = np.full((h, w, 3), 255, dtype=np.uint8)

    # 2. Occupied free space in gray (200), obstacles in dark gray (100).
    base_img[(filled_mask == 1) & (obstacle_map == 0)] = (200, 200, 200)
    base_img[(filled_mask == 1) & (obstacle_map == 1)] = (100, 100, 100)

    # 3. Zero-cell percentage inside the radius-k disc around each goal.
    positions = []
    percentages = []
    for goal in candidate_goals:
        # cv2 needs plain Python ints for point coordinates.
        y, x = (int(v) for v in np.argwhere(goal == 1)[0])
        positions.append((x, y))

        x_min, x_max = max(0, x - k), min(w, x + k + 1)
        y_min, y_max = max(0, y - k), min(h, y + k + 1)
        region = filled_mask[y_min:y_max, x_min:x_max]
        y_grid, x_grid = np.ogrid[y_min:y_max, x_min:x_max]
        in_circle = ((x_grid - x) ** 2 + (y_grid - y) ** 2) <= k ** 2

        zero_pixels = np.sum((region == 0) & in_circle)
        total_pixels = np.sum(in_circle)
        percentages.append((zero_pixels / total_pixels) * 100 if total_pixels > 0 else 0)

    # 4. Dots and labels. OpenCV colours are BGR, so red is (0, 0, 255).
    green, red = (0, 255, 0), (0, 0, 255)
    text_positions = []
    for (x, y), percent in zip(positions, percentages):
        cv2.circle(base_img, (x, y), 2, red, -1)

        text_color = green if percent < threshold else red

        # Put the label above the dot unless it would overlap an earlier one.
        offset_y = -15
        if any(abs(x - px) < 40 and abs(y - py) < 20 for px, py in text_positions):
            offset_y = 15
        text_pos = (x - 15, y + offset_y)
        text_positions.append(text_pos)

        cv2.putText(base_img, f"{percent:.1f}%", text_pos,
                    cv2.FONT_HERSHEY_SIMPLEX, 0.4, text_color, 1)

    # 5. Save and return.
    cv2.imwrite(output_path, base_img)
    return base_img


def get_candidate_goal_mask_and_replace_obs(depth, graph, graph2, obs, args, device):
    """
    Detect candidate exploration directions in a depth frame and write them into obs.

    ``graph.process_depth(depth)`` is tried first. ``graph2`` is used as a
    fallback when it returns zero candidates. Each candidate mask is grown
    into a disc by ``expand_masks``. If ``ds = args.env_frame_width //
    args.frame_width`` is not 1, the discs are subsampled by taking every
    ds-th pixel starting at ds // 2. The results are written in place into
    ``obs[0, 4:4 + n]``.

    Returns:
        (candidate_goal_mask, obs). The first item holds the un-expanded
        masks from process_depth, or ``np.zeros_like`` of them when there
        are no candidates. The second is the same ``obs`` object, mutated
        only when candidates exist.
    """
    raw_masks = graph.process_depth(depth)
    if raw_masks.shape[0] == 0:
        raw_masks = graph2.process_depth(depth)

    discs = expand_masks(raw_masks)
    if discs is None:
        # No candidate direction in the current view; obs is left unchanged.
        return np.zeros_like(raw_masks), obs

    count = discs.shape[0]
    stride = args.env_frame_width // args.frame_width
    if stride != 1:
        start = stride // 2
        discs = discs[:, start::stride, start::stride]

    obs[0, 4:4 + count, :, :] = torch.from_numpy(discs).to(device)
    return raw_masks, obs


def is_stuck(position_history, current_mask_pos, window_size=5, threshold=0.5):
    """
    Record the latest position and report whether the agent has barely moved.

    Args:
        position_history: container with ``append`` (e.g. a
            ``deque(maxlen=window_size)`` or a list). The current position
            is appended to it in place.
        current_mask_pos: the agent's current position, e.g. (row, col).
        window_size: number of most recent positions examined.
        threshold: the agent counts as stuck when the mean per-axis standard
            deviation of those positions is below this value.

    Returns:
        True once at least ``window_size`` positions are recorded and the
        spread of the newest ``window_size`` of them is below ``threshold``;
        otherwise False.
    """
    position_history.append(current_mask_pos)
    if len(position_history) < window_size:
        return False
    # Only the newest window is examined. A list, or a deque with a larger
    # maxlen, therefore keeps working after it grows past window_size.
    recent = list(position_history)[-window_size:]
    spread = np.std(np.array(recent), axis=0).mean()
    return spread < threshold


def is_goal_blocked(obstacle_map, tmp_goal, k=8, thresh=0.45):
    """
    Check whether obstacles crowd the radius-k disc around the goal.

    Args:
        obstacle_map (np.ndarray or torch.Tensor): obstacle grid; cells equal
            to 1 are counted as obstacles.
        tmp_goal (np.ndarray or torch.Tensor): same shape as obstacle_map,
            containing exactly one 1 (the goal).
        k (int): disc radius in pixels.
        thresh (float): obstacle-fraction threshold in [0, 1].

    Returns:
        bool: True if the fraction of obstacle cells inside the disc exceeds
        ``thresh``; otherwise False.

    Raises:
        ValueError: if the shapes differ, or tmp_goal does not contain
            exactly one 1.
    """
    obstacle_np = obstacle_map.cpu().numpy() if torch.is_tensor(obstacle_map) else obstacle_map
    goal_np = tmp_goal.cpu().numpy() if torch.is_tensor(tmp_goal) else tmp_goal

    if obstacle_np.shape != goal_np.shape:
        raise ValueError("obstacle_map和tmp_goal的shape必须相同")
    if np.sum(goal_np == 1) != 1:
        raise ValueError("tmp_goal中必须有且仅有一个1")

    goal_row, goal_col = np.argwhere(goal_np == 1)[0]
    n_rows, n_cols = obstacle_np.shape

    # Broadcast open grids instead of materialising two full index arrays.
    row_idx, col_idx = np.ogrid[:n_rows, :n_cols]
    inside = np.sqrt((row_idx - goal_row) ** 2 + (col_idx - goal_col) ** 2) <= k

    area = np.sum(inside)
    if area == 0:
        # Only possible for a negative radius.
        return False

    blocked = np.sum(obstacle_np[inside] == 1)
    return blocked / area > thresh


def get_Node_based_full_map(full_map, NodeDetector, planner_pose_inputs, args, infos):
    """
    Run the node detector on the current map and keep only the reachable nodes.

    Builds a (2, H, W) stack of the obstacle (channel 0 > 0) and explored
    (channel 1 > 0) masks of the first scene in ``full_map``. The stack is
    passed to ``NodeDetector.run`` together with
    ``infos[0]['trajectory_mask']`` and the agent position mask. Every node
    for which ``is_path_reachable`` finds a path from the agent is kept.

    Returns:
        numpy array of shape (num_reachable, H, W), or None when the
        detector yields no nodes or none of them is reachable.
    """
    obstacle_mask = full_map[0, 0, :, :] > 0
    exp_mask = full_map[0, 1, :, :] > 0
    obs_exp = np.stack(
        [obstacle_mask.cpu().numpy(), exp_mask.cpu().numpy()], axis=0)
    trajectory_mask = infos[0]['trajectory_mask']
    agent_position_mask = get_agent_position_mask(
        planner_pose_inputs, full_map.shape, args)

    node_map = NodeDetector.run(obs_exp, trajectory_mask, agent_position_mask)
    if node_map.shape[0] == 0:
        return None

    reachable = [
        node_map[idx]
        for idx in range(node_map.shape[0])
        if is_path_reachable(obstacle_mask, exp_mask, agent_position_mask, node_map[idx])
    ]
    return np.stack(reachable, axis=0) if reachable else None


def sample_escape_goal(obstacle_map, exp_map, history_position, initial_radius=20, max_radius=100):
    """
    Randomly pick a walkable cell near the position where the agent got stuck.

    The search disc around the stuck cell starts at ``initial_radius``. It
    grows by a factor of 1.5 (at least one pixel per step), capped at
    ``max_radius``. At the first radius whose disc contains any cell with
    exp_map == 1 and obstacle_map == 0, one such cell is chosen uniformly at
    random with the ``random`` module.

    Args:
        obstacle_map (np.ndarray): obstacle grid, 1 = obstacle.
        exp_map (np.ndarray): walkable/explored grid, 1 = walkable.
        history_position (np.ndarray): stuck-position mask summing to 1.
        initial_radius (int): first search radius in pixels.
        max_radius (int): largest search radius in pixels.

    Returns:
        tuple: (row, col) of the sampled cell, or None when no valid cell
        lies within ``max_radius`` (or ``initial_radius > max_radius``).

    Raises:
        AssertionError: if the shapes differ or history_position does not
            sum to 1.
    """
    assert obstacle_map.shape == exp_map.shape == history_position.shape
    assert np.sum(history_position) == 1, "历史位置掩码应有且仅有一个1"

    center_y, center_x = np.argwhere(history_position == 1)[0]
    height, width = obstacle_map.shape

    y_coords, x_coords = np.indices((height, width))
    distances = np.sqrt((y_coords - center_y) ** 2 + (x_coords - center_x) ** 2)

    # Walkability does not depend on the radius; compute it once.
    valid_mask = (exp_map == 1) & (obstacle_map == 0)

    current_radius = initial_radius
    while current_radius <= max_radius:
        candidate_points = np.argwhere((distances <= current_radius) & valid_mask)

        if len(candidate_points) > 0:
            chosen_idx = random.randint(0, len(candidate_points) - 1)
            return tuple(candidate_points[chosen_idx])

        if current_radius >= max_radius:
            # The largest disc has been searched; retrying it would loop forever.
            break

        # Grow geometrically, but by at least one pixel so small radii
        # (0 or 1, where int(r * 1.5) == r) still make progress.
        current_radius = min(max(int(current_radius * 1.5), current_radius + 1), max_radius)

    return None


def check_consistent_positions(history_position, planner_pose_inputs, full_map, args, k=25, min_common=20):
    """
    Record the agent's current cell and report whether it keeps recurring.

    The agent mask from ``get_agent_position_mask`` is appended to
    ``history_position`` in place.

    Args:
        history_position: deque (or list) of (H, W) position masks,
            presumably a ``deque(maxlen=k)``; TODO confirm at the call site.
        planner_pose_inputs, full_map, args: forwarded to
            ``get_agent_position_mask``.
        k (int): minimum number of stored masks before any verdict is given.
        min_common (int): how many stored masks must mark the same cell.

    Returns:
        bool: False while fewer than ``k`` masks are stored. Otherwise, True
        iff some cell is marked in at least ``min_common`` stored masks.
        Masks without a 1 (agent outside the map) are ignored when counting.
    """
    agent_position_mask = get_agent_position_mask(
        planner_pose_inputs, full_map.shape, args)
    history_position.append(agent_position_mask)

    if len(history_position) < k:
        return False

    visits = defaultdict(int)
    for mask in history_position:
        hits = np.argwhere(mask == 1)
        # An off-map agent yields an all-zero mask; skip instead of crashing.
        if len(hits):
            visits[tuple(hits[0])] += 1

    return any(count >= min_common for count in visits.values())


def is_path_reachable(obstacle_map, exp_map, agent_position_mask, tmp_goal_map):
    """
    Decide whether the goal cell can be reached from the agent cell.

    A cell is traversable when exp_map == 1 and obstacle_map == 0. Moves to
    all 8 neighbours (diagonals included) are allowed.

    Args:
        obstacle_map: (height, width) grid, 1 = obstacle.
        exp_map: (height, width) grid, 1 = explored / walkable.
        agent_position_mask: (height, width) mask with exactly one 1 (agent).
        tmp_goal_map: (height, width) mask with exactly one 1 (goal).
        Each argument may be a numpy array or a torch tensor (moved to CPU).

    Returns:
        bool: True if both endpoints are traversable and connected through
        traversable cells; otherwise False.

    Raises:
        ValueError: if the shapes differ, or either mask has zero or more
            than one 1.
    """
    def to_numpy(x):
        if hasattr(x, 'cpu'):  # PyTorch tensor
            return x.cpu().numpy()
        return x

    obstacle_map = to_numpy(obstacle_map)
    exp_map = to_numpy(exp_map)
    agent_position_mask = to_numpy(agent_position_mask)
    tmp_goal_map = to_numpy(tmp_goal_map)

    if obstacle_map.shape != exp_map.shape or obstacle_map.shape != agent_position_mask.shape or obstacle_map.shape != tmp_goal_map.shape:
        raise ValueError("所有地图的形状必须相同")

    # Explored and not an obstacle.
    traversable_map = (exp_map == 1) & (obstacle_map == 0)

    agent_rows, agent_cols = np.where(agent_position_mask == 1)
    if len(agent_rows) == 0:
        raise ValueError("在agent_position_mask中没有找到agent位置")
    if len(agent_rows) > 1:
        raise ValueError("agent_position_mask中有多个1，应该只有一个")
    agent_r, agent_c = agent_rows[0], agent_cols[0]

    goal_rows, goal_cols = np.where(tmp_goal_map == 1)
    if len(goal_rows) == 0:
        raise ValueError("在tmp_goal_map中没有找到目标位置")
    if len(goal_rows) > 1:
        raise ValueError("tmp_goal_map中有多个1，应该只有一个")
    goal_r, goal_c = goal_rows[0], goal_cols[0]

    # Both endpoints must be traversable.
    if not traversable_map[agent_r, agent_c]:
        return False
    if not traversable_map[goal_r, goal_c]:
        return False

    if agent_r == goal_r and agent_c == goal_c:
        return True

    # Label the 8-connected components of traversable cells. The goal is
    # reachable iff it lies in the agent's component; this matches an
    # 8-neighbour BFS but runs in C instead of a Python loop over every cell.
    components, _ = ndimage.label(traversable_map, structure=np.ones((3, 3), dtype=bool))
    return bool(components[agent_r, agent_c] == components[goal_r, goal_c])


class CandidateGoalIterator:
    """Hand out candidate goal channels one at a time, in index order.

    Wraps a stack of candidate goals. The original docs describe the stack as
    shaped ``(num_candidate, height, width)``; the code itself only requires
    ``len()`` plus integer and slice indexing (lists, numpy arrays and torch
    tensors all qualify).

    Two ways to consume it:

    * the explicit :meth:`next` method, which returns ``None`` once exhausted;
    * the Python iterator protocol (``for goal in it:``), which stops cleanly
      instead of yielding ``None``.

    Both share the same cursor, so mixing them continues where the other
    left off.
    """

    def __init__(self, candidate_goals=None):
        """Create an iterator over ``candidate_goals``.

        :param candidate_goals: indexable stack of candidate goals, or
            ``None`` for an iterator that is empty from the start.
        """
        self.set_candidate_goals(candidate_goals)

    def set_candidate_goals(self, candidate_goals):
        """Replace the candidate goals and rewind to the first one.

        :param candidate_goals: new indexable stack of candidate goals, or
            ``None``.
        """
        self.candidate_goals = candidate_goals
        self.current_index = 0  # cursor: index of the next goal to hand out

    def _has_remaining(self):
        """Return True if at least one goal has not been handed out yet."""
        # Short-circuit keeps len() from being called on None.
        return (self.candidate_goals is not None
                and self.current_index < len(self.candidate_goals))

    def next(self):
        """Return the next candidate goal channel and advance the cursor.

        :return: ``candidate_goals[current_index]``, or ``None`` when there
            are no goals or all of them have already been returned.
        """
        if not self._has_remaining():
            return None

        channel = self.candidate_goals[self.current_index]
        self.current_index += 1
        return channel

    def get_remaining_goals(self):
        """Return every goal not yet handed out, without advancing the cursor.

        :return: ``candidate_goals[current_index:]`` (a view for numpy arrays
            and tensors), or ``None`` when nothing remains.
        """
        if not self._has_remaining():
            return None
        return self.candidate_goals[self.current_index:]

    def __iter__(self):
        """Return self, so the object can be used directly in a ``for`` loop."""
        return self

    def __next__(self):
        """Iterator-protocol counterpart of :meth:`next`.

        :raises StopIteration: when no goals remain. Unlike :meth:`next`,
            this does not return ``None`` at the end.
        """
        if not self._has_remaining():
            raise StopIteration
        return self.next()


# def expand_semantic_mask(semantic_mask, up_scale=2, down_scale=1):
#     processed_mask = semantic_mask.cpu().numpy().copy()
#     num_scenes, num_channels, height, width = semantic_mask.shape

#     for scene_idx in range(num_scenes):
#         for class_idx in range(4, num_channels):
#             mask = processed_mask[scene_idx, class_idx]
#             count = np.sum(mask)

#             if 0 < count < 100:
#                 # Find original x boundaries
#                 ys, xs = np.where(mask)
#                 if len(xs) == 0:
#                     continue
#                 x_start, x_end = np.min(xs), np.max(xs)

#                 # Calculate new x boundaries
#                 width_original = x_end - x_start
#                 delta = int(round(0.05 * width_original))
#                 new_x_start = x_start + delta
#                 new_x_end = x_end - delta

#                 # Apply x-axis contraction
#                 new_mask = np.zeros_like(mask)
#                 new_x_start = max(0, new_x_start)
#                 new_x_end = min(width-1, new_x_end)
#                 if new_x_start <= new_x_end:
#                     new_mask[:, new_x_start:new_x_end +
#                              1] = mask[:, new_x_start:new_x_end+1]

#                 # Find y boundaries after contraction
#                 ys_new, _ = np.where(new_mask)
#                 if len(ys_new) == 0:
#                     processed_mask[scene_idx, class_idx] = new_mask
#                     continue

#                 y_start, y_end = np.min(ys_new), np.max(ys_new)
#                 h = y_end - y_start + 1

#                 # Calculate new y boundaries
#                 up_extend = int(round(up_scale * h))
#                 down_extend = int(round(down_scale * h))
#                 y_start_new = max(0, y_start - up_extend)
#                 y_end_new = min(height-1, y_end + down_extend)

#                 # Apply y-axis expansion
#                 final_mask = np.zeros_like(new_mask)
#                 if new_x_start <= new_x_end:
#                     final_mask[y_start_new:y_end_new +
#                                1, new_x_start:new_x_end+1] = 1

#                 processed_mask[scene_idx, class_idx] = final_mask

#     return torch.from_numpy(processed_mask).to(semantic_mask.device)


# def expand_semantic_channels(obs, target_threshold=400, start_channel=4, end_channel=9):
#     """
#     扩展语义通道中的小目标区域到指定的阈值大小

#     Args:
#         obs: shape为[1, 30, height, width]的tensor或numpy数组
#         target_threshold: 目标阈值，如果非零值个数小于此值则进行扩展
#         start_channel: 起始通道索引（包含）
#         end_channel: 结束通道索引（包含）

#     Returns:
#         处理后的数组，保持输入相同的shape和类型
#     """

#     # 判断输入类型并转换为numpy处理
#     is_tensor = torch.is_tensor(obs)
#     if is_tensor:
#         obs_np = obs.clone().cpu().numpy()
#         device = obs.device
#     else:
#         obs_np = obs.copy()

#     # 创建输出数组
#     output = obs_np.copy()

#     # 获取维度信息
#     batch_size, num_channels, height, width = obs_np.shape

#     # 验证输入shape
#     if batch_size != 1:
#         raise ValueError(f"期望batch_size为1，但得到{batch_size}")

#     # 处理指定通道范围（双闭区间）
#     for channel_idx in range(start_channel, end_channel + 1):
#         if channel_idx >= num_channels:
#             continue

#         # 获取当前通道的掩码
#         current_mask = obs_np[0, channel_idx]

#         # 检测非零值
#         nonzero_positions = np.argwhere(current_mask != 0)
#         num_nonzero = len(nonzero_positions)

#         # 如果没有非零值或者非零值个数已经超过阈值，跳过
#         if num_nonzero == 0 or num_nonzero >= target_threshold:
#             continue

#         print(
#             f"通道 {channel_idx}: 检测到 {num_nonzero} 个非零值，小于阈值 {target_threshold}，开始扩展...")

#         # 计算非零值的中心坐标
#         center_y = np.mean(nonzero_positions[:, 0])
#         center_x = np.mean(nonzero_positions[:, 1])
#         cy, cx = int(round(center_y)), int(round(center_x))

#         print(f"  中心坐标: ({cy}, {cx})")

#         # 计算边界限制的最大半径
#         top_dist = cy
#         bottom_dist = height - 1 - cy
#         left_dist = cx
#         right_dist = width - 1 - cx
#         max_radius = min(top_dist, bottom_dist, left_dist, right_dist)

#         # 计算理论目标半径
#         target_radius = sqrt(target_threshold / pi)

#         # 确定实际使用的半径
#         if max_radius >= target_radius:
#             final_radius = target_radius
#             print(f"  使用目标半径: {final_radius:.2f}")
#         else:
#             final_radius = max_radius
#             print(f"  受边界限制，使用最大半径: {final_radius:.2f}")

#         # 创建网格坐标
#         Y, X = np.ogrid[:height, :width]

#         # 计算每个点到中心的距离
#         dist = np.sqrt((X - cx)**2 + (Y - cy)**2)

#         # 创建圆形区域掩码
#         circle_mask = dist <= final_radius

#         # 计算实际扩展后的像素数
#         actual_pixels = np.sum(circle_mask)
#         print(f"  扩展后像素数: {actual_pixels}")

#         # 更新当前通道的掩码（保持原有的非零值）
#         new_mask = np.zeros_like(current_mask)
#         new_mask[circle_mask] = 1

#         # 将处理后的掩码更新到输出数组
#         output[0, channel_idx] = new_mask

#     # 转换回原始类型
#     if is_tensor:
#         return torch.from_numpy(output).to(device)
#     else:
#         return output
